'''
Description: This script evaluates OCR recognition accuracy on text-line crops.
Author: KouXichao
Date: 2020-10-31 21:14:59
LastEditors: KouXichao
LastEditTime: 2020-11-18 19:58:06
FilePath: \PaddleOCR\eval_rec_acc_images.py
Github: https://github.com/kouxichao
'''

import os
import argparse
import numpy as np
import cv2
import Levenshtein
import math
import shutil  
import re
from zhon.hanzi import punctuation 
import string

class SynthLayout:
    '''
    description: Synthetic document dataset that yields text-line crops and
        their ground-truth text for OCR evaluation.

        NOTE: constructing an instance modifies the dataset folder. Any existing
        ``crop_line_images/`` directory and ``line_image_names.txt`` file under
        ``layout_folder`` are deleted and an empty ``crop_line_images/``
        directory is created.
    param {*} layout_folder: dataset root. The code reads ``image_names.npy``,
        ``image/<name>`` and ``regionInfo/<name before first dot>/bboxes.npy``
        from it.
    param {*} warp: stored as ``self.warp``; not otherwise used by this class.
    '''
    def __init__(self, layout_folder, warp=False):
        super().__init__()
        self.layout_folder = layout_folder
        self.warp = warp
        # os.path.join works with or without a trailing separator on
        # layout_folder, matching how the other paths in this class are built.
        self.images = np.load(os.path.join(layout_folder, 'image_names.npy')).tolist()
        # Chinese punctuation (zhon) followed by ASCII punctuation.
        self.punc = punctuation + string.punctuation
        # SECURITY: allow_pickle=True unpickles arbitrary objects; only point
        # this class at datasets from a trusted source.
        self.all_bboxes = [
            np.load(
                os.path.join(layout_folder, 'regionInfo', name.split('.')[0], 'bboxes.npy'),
                allow_pickle=True,
            )
            for name in self.images
        ]

        # Start every run from a clean output directory and names file.
        self.line_image_path = os.path.join(self.layout_folder, "crop_line_images/")
        if os.path.exists(self.line_image_path):
            shutil.rmtree(self.line_image_path)
        names_file = os.path.join(self.layout_folder, 'line_image_names.txt')
        if os.path.exists(names_file):
            os.remove(names_file)
        os.mkdir(self.line_image_path)

    def __len__(self):
        '''
        :return: number of document images in the dataset
        '''
        return len(self.images)

    def get_imagename(self, index):
        '''
        :param: index image index, in the range 0 .. len(self.images) - 1
        :return: image file name (including its extension)
        '''
        return self.images[index]

    def load_image_gt(self, index):
        '''
        :param index: image index, in the range 0 .. len(self.images) - 1.
        :return: image data (numpy, converted from BGR to RGB), character boxes,
            word boxes, text-line boxes, image name without extension.
        :raises FileNotFoundError: if the image file cannot be read.
        '''
        assert 0 <= index < len(self.images)
        image_name = self.images[index]
        img_path = os.path.join(self.layout_folder, 'image/' + image_name)
        image = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than inside cvtColor.
            raise FileNotFoundError("cannot read image: {}".format(img_path))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # bboxes.npy stores the word, line and character boxes in that order.
        bboxes = self.all_bboxes[index]
        word_bboxes, line_bboxes, char_bboxes = bboxes[0], bboxes[1], bboxes[2]
        return image, char_bboxes, word_bboxes, line_bboxes, image_name.split(".")[0]

    def crop_image_line(self, index):
        '''
        description:
            Crop every text line of one document image and collect its
            ground-truth text (straight lines only). TODO: crop warped lines.
            The generated line-image paths are appended to
            ``line_image_names.txt`` in the dataset folder; the crops
            themselves are returned in memory and are not written to disk.
        param
            index(int): document image index
        return:
            imagesOfLine(list): cropped line images
            textOfLine(list): line texts (words concatenated, stripped)
            textOfLineWithoutPunc(list): line texts with punctuation removed
                from each word and the words joined by single spaces
            line_image_name_list(list): path generated for each line image
        '''
        # Use this instance rather than a module-level global so the method
        # works regardless of what the caller named its loader object.
        image, _, word_bboxes, line_bboxes, _ = self.load_image_gt(index)
        # Compile once per call instead of once per word.
        punc_pattern = re.compile(r"[%s]+" % (self.punc))

        imagesOfLine = []
        textOfLine = []
        textOfLineWithoutPunc = []
        line_image_name_list = []
        for i in range(line_bboxes.shape[0]):
            # The last element of every word box is the word's text.
            words = [wb[-1] for wb in word_bboxes[i]]
            text = "".join(words).strip()
            # Remove punctuation per word, then collapse all whitespace so that
            # words reduced to nothing leave no stray spaces behind.
            stripped_words = [punc_pattern.sub("", word) for word in words]
            text_without_punc = ' '.join(' '.join(stripped_words).split())

            # Rows come from box[1]:box[5], columns from box[0]:box[2].
            # NOTE(review): presumably the box is [x0, y0, x1, y1, x2, y2, x3, y3]
            # with corners clockwise from top-left - confirm against the dataset.
            box = line_bboxes[i]
            cropped = image[box[1]:box[5], box[0]:box[2], :]

            line_image_name = os.path.join(self.line_image_path, str(index) + "_" + str(i) + "line.jpg")

            imagesOfLine.append(cropped)
            textOfLine.append(text)
            textOfLineWithoutPunc.append(text_without_punc)
            line_image_name_list.append(line_image_name)

        # Append all names with a single open; skip it entirely when there are
        # no lines so that no empty file is created.
        if line_image_name_list:
            with open(os.path.join(self.layout_folder, "line_image_names.txt"), 'a') as f:
                f.writelines(name + "\n" for name in line_image_name_list)
        return imagesOfLine, textOfLine, textOfLineWithoutPunc, line_image_name_list

# Import the project-local text-recognition modules
import tools.infer.predict_rec as predict_rec
import tools.infer.utility as utility

def accuracy_from_distance(distance, total_chars):
    '''
    description: Turn an accumulated edit distance into an accuracy ratio.
    param
        distance(int): summed Levenshtein distance over all lines
        total_chars(int): summed ground-truth length over all lines
    return:
        float: 1 - distance / total_chars, or NaN when total_chars is 0
            (nothing was evaluated, so no accuracy can be stated)
    '''
    if total_chars == 0:
        return float("nan")
    return 1 - distance / total_chars


if __name__ == '__main__':
    # Evaluate recognition accuracy over every text line of every document
    # image, reporting three edit-distance based metrics: raw text, text with
    # whitespace removed, and text with punctuation removed.
    args = utility.parse_args()
    dataloader = SynthLayout(args.image_dir)
    nums = len(dataloader)

    # Build the recognizer once: `args` never changes between images, so
    # re-creating it per image is redundant work (construction presumably
    # loads the model - confirm in tools/infer/predict_rec.py).
    textRecognizer = predict_rec.TextRecognizer(args)

    # Same punctuation set the ground truth is cleaned with; compiled once.
    punc_pattern = re.compile(r"[%s]+" % (punctuation + string.punctuation))

    dist = 0
    dist_without_space = 0
    dist_without_punc = 0
    total_char_num = 0
    total_char_num_without_space = 0
    total_char_num_without_punc = 0
    for i in range(nums):
        # data for ocr
        imagesOfLine, textOfLine, text_without_punc, line_image_name_list = dataloader.crop_image_line(i)
        # recognition
        rec_res, elapse = textRecognizer(imagesOfLine)

        # Arguments are ordered to match the labels: line count first, then
        # the number of recognition results.
        print("\033[31m----------image num : {}----------\033[0m\nline num: {}, rec_res num  : {}, elapse : {}".format(i, len(imagesOfLine), len(rec_res), elapse))
        # zip() below would silently drop lines on a length mismatch, so fail
        # loudly instead (and independently of `python -O`).
        if len(imagesOfLine) != len(rec_res):
            raise RuntimeError("image {}: {} line crops but {} recognition results".format(i, len(imagesOfLine), len(rec_res)))

        # calculate edit distance and accuracy of ocr
        for gt, res, gt_nopunc, imgname in zip(textOfLine, rec_res, text_without_punc, line_image_name_list):
            # accuracy on the raw output; res[0] is the recognized text
            pre = res[0]
            dist += Levenshtein.distance(gt, pre)
            total_char_num += len(gt)
            print("line image: ", imgname)
            print("\t\033[33morig text: {}\033[0m\n\t\033[34mpred text: {}\033[0m".format(gt, pre))
            print()

            # accuracy with all whitespace removed; the denominator counts the
            # space-free ground truth so it matches what the distance measures
            gt_without_space = ''.join(gt.split())
            pre_without_space = ''.join(pre.split())
            dist_without_space += Levenshtein.distance(gt_without_space, pre_without_space)
            total_char_num_without_space += len(gt_without_space)
            print("\t\033[33morig text(without space): {}\033[0m\n\t\033[34mpred text(without space): {}\033[0m".format(gt_without_space, pre_without_space))

            # accuracy with punctuation removed; punctuation in the prediction
            # becomes a space and whitespace runs are then collapsed
            pre_without_punc = ' '.join(punc_pattern.sub(" ", pre).split())
            dist_without_punc += Levenshtein.distance(gt_nopunc, pre_without_punc)
            total_char_num_without_punc += len(gt_nopunc)
            print("\t\033[33morig text(without punc): {}\033[0m\n\t\033[34mpred text(without punc): {}\033[0m".format(gt_nopunc, pre_without_punc))

    print("***********Edit Distance: {}, Total Char Num: {}, Accuracy: {}**********".format(dist, total_char_num, accuracy_from_distance(dist, total_char_num)))
    print("***********Edit Distance(without space): {}, Total Char Num: {}, Accuracy: {}**********".format(dist_without_space, total_char_num_without_space, accuracy_from_distance(dist_without_space, total_char_num_without_space)))
    print("***********Edit Distance(without punc): {}, Total Char Num: {}, Accuracy: {}**********".format(dist_without_punc, total_char_num_without_punc, accuracy_from_distance(dist_without_punc, total_char_num_without_punc)))

